#!/usr/bin/env python
# encoding: utf-8
'''
@author: wangjianrong
@software: pycharm
@file: 混合精度.py
@time: 2020/10/13 17:21
@desc: Notes on mixed-precision (FP16) training with NVIDIA apex `amp`.
Reference: https://fyubang.com/2019/08/26/fp16/
'''

'''
apex安装方法
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cpp_ext --cuda_ext
'''

from apex import amp

def init_amp(model, optimizer, opt_level="O1"):
    """Prepare *model* and *optimizer* for apex mixed-precision training.

    Args:
        model: the model to train (presumably a ``torch.nn.Module`` -- TODO
            confirm against callers).
        optimizer: the optimizer that updates ``model``'s parameters.
        opt_level: apex optimization level passed to ``amp.initialize``.
            Note the value is the capital letter "O" followed by a digit
            ("O1"), not the digits zero-one ("01").

    Returns:
        The ``(model, optimizer)`` pair returned by ``amp.initialize``.
        Rebind your own references to these returned objects and use them
        for the rest of training.
    """
    return amp.initialize(model, optimizer, opt_level=opt_level)


def scaled_backward(loss, optimizer):
    """Back-propagate *loss* through apex's loss-scaling context.

    Gradients are computed by calling ``backward()`` on the scaled loss that
    ``amp.scale_loss`` yields, rather than on *loss* directly.

    Args:
        loss: the loss value for the current step.
        optimizer: the optimizer previously returned by :func:`init_amp`.

    This function does not call ``optimizer.step()``. The caller still
    performs the parameter update (and zeroes gradients) as usual.
    """
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()


# Typical use (``model``, ``optimizer`` and ``loss`` come from your own
# training code; they are not defined in this module):
#
#     model, optimizer = init_amp(model, optimizer)   # note: "O1", not "01"
#     ...
#     scaled_backward(loss, optimizer)
#     optimizer.step()